# -*- coding: utf-8 -*-
"""
Created on Thu Mar 23 17:14:52 2023

@author: jan
"""


# You have some blue points on a plane.
# Drag to move them. Left-click on empty space to create a new one; right-click
# on a point to delete it.

# Each point has a positive value `a`, which represents how much the point
# "attracts" (or "repulses") newly generated orange points. You can change the
# logarithm of this value with the sliders at the bottom of the window.
# After each change to the points, `N=1000` orange points appear. They are
# created as convex combinations `d1 X1 + d2 X2 + ... + dn Xn`, where the X's
# are the blue points and the d's are randomly drawn from Dirichlet(a1, a2, ..., an).



import tkinter as tk
import numpy as np


# How many orange points are sampled and drawn after every change to the
# blue points.
N = 1000

class Point:
    """A blue control point on the canvas.

    Attributes:
        x, y: canvas coordinates of the point's centre.
        radius: half-width of the drawn circle and of the square click hit box.
        color: Tk fill colour used when drawing the point.
        a: log-weight of the point; the sampler uses exp(a) as this point's
           Dirichlet concentration.
    """

    def __init__(self, x, y, a=0, color="blue", radius=5):
        # `radius` is appended last (default 5, the previously hard-coded
        # value) so existing positional calls keep working.
        self.x = x
        self.y = y
        self.radius = radius
        self.color = color
        self.a = a

    def __repr__(self):
        return "{}(x={!r}, y={!r}, a={!r}, color={!r}, radius={!r})".format(
            type(self).__name__, self.x, self.y, self.a, self.color, self.radius
        )


class CanvasWithPoints:
    """Interactive canvas of draggable blue points plus random orange samples.

    Every blue point owns a vertical slider (in a row under the canvas) that
    sets its log-weight ``point.a``. After any change the canvas is redrawn
    with ``N`` orange points, each a random convex combination of the blue
    points with weights drawn from Dirichlet(exp(a_1), ..., exp(a_k)).
    """

    def __init__(self, root):
        """Create the canvas and slider row inside *root* and bind mouse events."""
        self.points = []  # blue points, in creation (and drawing) order
        self.selected_point = None  # point currently being dragged, if any
        self.canvas = tk.Canvas(root, width=1024, height=786, bg="white")
        self.canvas.pack()
        self.canvas.bind("<Button-1>", self.on_click)
        self.canvas.bind("<B1-Motion>", self.on_drag)
        self.canvas.bind("<ButtonRelease-1>", self.on_release)
        self.canvas.bind("<Button-3>", self.on_right_click)

        self.sliders = []  # sliders[i] belongs to points[i]
        self.slider_frame = tk.Frame(root)
        self.slider_frame.pack(side=tk.BOTTOM)

    def get_selected_point(self, event):
        """Return the first point whose hit box contains the event position, or None."""
        return next((p for p in self.points if self.is_inside(event, p)), None)

    def is_inside(self, event, point):
        """Return True if (event.x, event.y) lies in *point*'s square hit box."""
        x_min = point.x - point.radius
        x_max = point.x + point.radius
        y_min = point.y - point.radius
        y_max = point.y + point.radius
        return x_min <= event.x <= x_max and y_min <= event.y <= y_max

    def add_point(self, x, y):
        """Add a blue point at (x, y), redraw, and create its slider."""
        self.points.append(Point(x, y))
        self.draw_points()
        self.add_slider()

    def add_slider(self):
        """Create the slider for the most recently added point.

        The slider runs from 5 (top) down to -7 in steps of 0.5 and forwards
        its value to ``update_a`` for that point. It starts at the mean
        log-weight of the other points, or at ``point.a`` for the first point.
        """
        point = self.points[-1]
        # `point` is a local bound once per call, so each lambda keeps
        # referring to its own point even after other points are removed.
        slider = tk.Scale(self.slider_frame, from_=5, to=-7, resolution=0.5,
                          orient="vertical",
                          command=lambda value: self.update_a(point, float(value)))
        others = self.points[:-1]
        if others:
            slider.set(sum(p.a for p in others) / len(others))
        else:
            slider.set(point.a)

        slider.pack(side=tk.LEFT)
        self.sliders.append(slider)

    def remove_point(self, point):
        """Delete *point* and its slider, then redraw."""
        index = self.points.index(point)
        self.points.remove(point)
        self.remove_slider(index)
        self.draw_points()

    def remove_slider(self, index):
        """Destroy the slider at *index*.

        ``pack_forget`` alone would only hide the widget and leak it.
        """
        slider = self.sliders.pop(index)
        slider.destroy()

    def draw_points(self):
        """Clear the canvas, draw fresh orange samples, then the blue points.

        The blue points are drawn last so the samples do not cover them.
        """
        self.canvas.delete("all")
        self.draw_generated()
        for point in self.points:
            r = point.radius
            self.canvas.create_oval(point.x - r, point.y - r,
                                    point.x + r, point.y + r,
                                    fill=point.color)

    def draw_generated(self):
        """Sample up to ``N`` orange points and draw each as a small circle."""
        for x, y in self._sample_generated(N):
            self.canvas.create_oval(x - 2, y - 2, x + 2, y + 2, fill="orange")

    def _sample_generated(self, n):
        """Return an ``(m, 2)`` array of m <= n random convex combinations.

        Weights come from Dirichlet(exp(a_1), ..., exp(a_k)) over the current
        points. Returns an empty ``(0, 2)`` array when there are no points.
        """
        if not self.points:
            # With no points the product below collapses to a 1-D array and
            # unpacking it as (x, y) would fail.
            return np.empty((0, 2))
        P = np.array([(p.x, p.y) for p in self.points], dtype=float)
        alpha = np.exp([p.a for p in self.points])
        D = np.random.dirichlet(alpha, size=n)
        generated = D @ P
        # For tiny concentrations (slider near -7, exp(-7) ~ 1e-3) all gamma
        # draws of a row can underflow to 0, which can yield NaN rows. Tk
        # cannot draw at NaN coordinates, so keep only finite samples.
        return generated[np.isfinite(generated).all(axis=1)]

    def on_right_click(self, event):
        """Delete the point under the cursor, if any (remove_point redraws)."""
        point = self.get_selected_point(event)
        if point is not None:
            self.remove_point(point)

    def on_click(self, event):
        """Start dragging the point under the cursor, or add a new point there."""
        self.selected_point = self.get_selected_point(event)
        if self.selected_point is None:
            self.add_point(event.x, event.y)

    def on_drag(self, event):
        """Move the selected point to the cursor and redraw."""
        if self.selected_point is not None:
            self.selected_point.x = event.x
            self.selected_point.y = event.y
            self.draw_points()

    def on_release(self, event):
        """Stop dragging."""
        self.selected_point = None

    def update_a(self, point, value):
        """Slider callback: store the new log-weight on *point* and redraw."""
        point.a = float(value)
        self.draw_points()




if __name__ == "__main__":
    # Seed the canvas with a triangle of three blue points, then run Tk.
    root = tk.Tk()
    canvas = CanvasWithPoints(root)
    for start in ((50, 50), (500, 50), (200, 500)):
        canvas.add_point(*start)
    root.mainloop()